{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "csv.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "toc_visible": true
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DweYe9FcbMK_"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "AVV2e0XKbJeX",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sUtoed20cRJJ"
      },
      "source": [
        "# tf.data を使って CSV をロードする"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1ap_W4aQcgNT"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/load_data/csv\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/ja/tutorials/load_data/csv.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/ja/tutorials/load_data/csv.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RSywPQ2n736s"
      },
      "source": [
        "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "C-3Xbt0FfGfs"
      },
      "source": [
        "このチュートリアルでは、CSV データを `tf.data.Dataset` にロードする手法の例を示します。\n",
        "\n",
        "このチュートリアルで使われているデータはタイタニック号の乗客リストから取られたものです。乗客が生き残る可能性を、年齢、性別、チケットの等級、そして乗客が一人で旅行しているか否かといった特性から予測することを試みます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fgZ9gjmPfSnK"
      },
      "source": [
        "## 設定"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "I4dwMQVQMQWD",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  !pip install tf-nightly-2.0-preview\n",
        "except Exception:\n",
        "  pass\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "baYFZMW_bJHh",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ncf5t6tgL5ZI",
        "colab": {}
      },
      "source": [
        "TRAIN_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\"\n",
        "TEST_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/eval.csv\"\n",
        "\n",
        "train_file_path = tf.keras.utils.get_file(\"train.csv\", TRAIN_DATA_URL)\n",
        "test_file_path = tf.keras.utils.get_file(\"eval.csv\", TEST_DATA_URL)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "4ONE94qulk6S",
        "colab": {}
      },
      "source": [
        "# numpy の値を読みやすくする\n",
        "np.set_printoptions(precision=3, suppress=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Wuqj601Qw0Ml"
      },
      "source": [
        "## データのロード\n",
        "\n",
        "それではいつものように、扱っている CSV ファイルの先頭を見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "54Dv7mCrf9Yw",
        "colab": {}
      },
      "source": [
        "!head {train_file_path}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YOYKQKmMj3D6"
      },
      "source": [
        "ご覧のように、CSV の列にはラベルが付いています。後ほど必要になるので、ファイルから読み出しておきましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "v0sLG216MtwT",
        "colab": {}
      },
      "source": [
        "# 入力ファイル中の CSV 列\n",
        "with open(train_file_path, 'r') as f:\n",
        "    names_row = f.readline()\n",
        "\n",
        "\n",
        "CSV_COLUMNS = names_row.rstrip('\\n').split(',')\n",
        "print(CSV_COLUMNS)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZS-bt1LvWn2x"
      },
      "source": [
        " データセットコンストラクタはこれらのラベルを自動的にピックアップします。\n",
        "\n",
        "使用するファイルの1行目に列名がない場合、`make_csv_dataset` 関数の `column_names` 引数に文字列のリストとして渡します。\n",
        "\n",
        "```python\n",
        "\n",
        "CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']\n",
        "\n",
        "dataset = tf.data.experimental.make_csv_dataset(\n",
        "     ...,\n",
        "     column_names=CSV_COLUMNS,\n",
        "     ...)\n",
        "  \n",
        "```\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gZfhoX7bR9u4"
      },
      "source": [
        "この例では使用可能な列をすべて使うことになります。データセットから列を除く必要がある場合には、使用したい列だけを含むリストを作り、コンストラクタの（オプションである）`select_columns` 引数として渡します。\n",
        "\n",
        "\n",
        "```python\n",
        "\n",
        "drop_columns = ['fare', 'embark_town']\n",
        "columns_to_use = [col for col in CSV_COLUMNS if col not in drop_columns]\n",
        "\n",
        "dataset = tf.data.experimental.make_csv_dataset(\n",
        "  ...,\n",
        "  select_columns = columns_to_use, \n",
        "  ...)\n",
        "\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "67mfwr4v-mN_"
      },
      "source": [
        "各サンプルのラベルとなる列を特定し、それが何であるかを示す必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "iXROZm5f3V4E",
        "colab": {}
      },
      "source": [
        "LABELS = [0, 1]\n",
        "LABEL_COLUMN = 'survived'\n",
        "\n",
        "FEATURE_COLUMNS = [column for column in CSV_COLUMNS if column != LABEL_COLUMN]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t4N-plO4tDXd"
      },
      "source": [
        "コンストラクタの引数の値が揃ったので、ファイルから CSV データを読み込みデータセットを作ることにしましょう。\n",
        "\n",
        "（完全なドキュメントは、`tf.data.experimental.make_csv_dataset` を参照してください）"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Co7UJ7gpNADC",
        "colab": {}
      },
      "source": [
        "def get_dataset(file_path):\n",
        "  dataset = tf.data.experimental.make_csv_dataset(\n",
        "      file_path,\n",
        "      batch_size=12, # 見やすく表示するために意図して小さく設定しています\n",
        "      label_name=LABEL_COLUMN,\n",
        "      na_value=\"?\",\n",
        "      num_epochs=1,\n",
        "      ignore_errors=True)\n",
        "  return dataset\n",
        "\n",
        "raw_train_data = get_dataset(train_file_path)\n",
        "raw_test_data = get_dataset(test_file_path)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vHUQFKoQI6G7"
      },
      "source": [
        "データセットを構成する要素は、(複数のサンプル, 複数のラベル)の形のタプルとして表されるバッチです。サンプル中のデータは（行ベースのテンソルではなく）列ベースのテンソルとして構成され、それぞれはバッチサイズ（このケースでは12個）の要素が含まれます。\n",
        "\n",
        "実際に見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qWtFYtwXIeuj",
        "colab": {}
      },
      "source": [
        "examples, labels = next(iter(raw_train_data)) # 最初のバッチのみ\n",
        "print(\"EXAMPLES: \\n\", examples, \"\\n\")\n",
        "print(\"LABELS: \\n\", labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9cryz31lxs3e"
      },
      "source": [
        "## データの前処理"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tSyrkSQwYHKi"
      },
      "source": [
        "### カテゴリデータ\n",
        "\n",
        "この CSV データ中のいくつかの列はカテゴリ列です。つまり、その中身は、限られた選択肢の中のひとつである必要があります。\n",
        "\n",
        "この CSV では、これらの選択肢はテキストとして表現されています。このテキストは、モデルの訓練を行えるように、数字に変換する必要があります。これをやりやすくするため、カテゴリ列のリストとその選択肢のリストを作成する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "mWDniduKMw-C",
        "colab": {}
      },
      "source": [
        "CATEGORIES = {\n",
        "    'sex': ['male', 'female'],\n",
        "    'class' : ['First', 'Second', 'Third'],\n",
        "    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],\n",
        "    'embark_town' : ['Cherbourg', 'Southhampton', 'Queenstown'],\n",
        "    'alone' : ['y', 'n']\n",
        "}\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8Ii0YWsoKBVx"
      },
      "source": [
        "カテゴリ値のテンソルを受け取り、それを値の名前のリストとマッチングして、さらにワンホット・エンコーディングを行う関数を書きます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "bP02_BflkDbv",
        "colab": {}
      },
      "source": [
        "def process_categorical_data(data, categories):\n",
        "  \"\"\"カテゴリ値を表すワンホット・エンコーディングされたテンソルを返す\"\"\"\n",
        "  \n",
        "  # 最初の ' ' を取り除く\n",
        "  data = tf.strings.regex_replace(data, '^ ', '')\n",
        "  # 最後の '.' を取り除く\n",
        "  data = tf.strings.regex_replace(data, r'\\.$', '')\n",
        "  \n",
        "  # ワンホット・エンコーディング\n",
        "  # data を1次元（リスト）から2次元（要素が1個のリストのリスト）にリシェープ\n",
        "  data = tf.reshape(data, [-1, 1])\n",
        "  # それぞれの要素について、カテゴリ数の長さの真偽値のリストで、\n",
        "  # 要素とカテゴリのラベルが一致したところが True となるものを作成\n",
        "  data = categories == data\n",
        "  # 真偽値を浮動小数点数にキャスト\n",
        "  data = tf.cast(data, tf.float32)\n",
        "  \n",
        "  # エンコーディング全体を次の1行に収めることもできる：\n",
        "  # data = tf.cast(categories == tf.reshape(data, [-1, 1]), tf.float32)\n",
        "  return data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "To2qbBGGMO1D"
      },
      "source": [
        "この処理を可視化するため、最初のバッチからカテゴリ列のテンソル1つを取り出し、処理を行い、前後の状態を示します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ds7MOLMkK2Gf",
        "colab": {}
      },
      "source": [
        "class_tensor = examples['class']\n",
        "class_tensor"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HdDUSgpoTKfA",
        "colab": {}
      },
      "source": [
        "class_categories = CATEGORIES['class']\n",
        "class_categories"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "yHQeR47_ObpT",
        "colab": {}
      },
      "source": [
        "processed_class = process_categorical_data(class_tensor, class_categories)\n",
        "processed_class"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ACkc_cCaTuos"
      },
      "source": [
        "2つの入力の長さと、出力の形状の関係に注目してください。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "gvvXM8m0T00O",
        "colab": {}
      },
      "source": [
        "print(\"Size of batch: \", len(class_tensor.numpy()))\n",
        "print(\"Number of category labels: \", len(class_categories))\n",
        "print(\"Shape of one-hot encoded tensor: \", processed_class.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9AsbaFmCeJtF"
      },
      "source": [
        "### 連続データ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "o2maE8d2ijsq"
      },
      "source": [
        "連続データは値が0と1の間にになるように標準化する必要があります。これを行うために、それぞれの値を、1を列値の平均の2倍で割ったものを掛ける関数を書きます。\n",
        "\n",
        "この関数は、データの2次元のテンソルへのリシェープも行います。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IwGOy61lkQw-",
        "colab": {}
      },
      "source": [
        "def process_continuous_data(data, mean):\n",
        "  # data の標準化\n",
        "  data = tf.cast(data, tf.float32) * 1/(2*mean)\n",
        "  return tf.reshape(data, [-1, 1])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0Yh8R7BujTAu"
      },
      "source": [
        "この計算を行うためには、列値の平均が必要です。現実には、この値を計算する必要があるのは明らかですが、この例のために値を示します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "iNE_mTJqegGQ",
        "colab": {}
      },
      "source": [
        "MEANS = {\n",
        "    'age' : 29.631308,\n",
        "    'n_siblings_spouses' : 0.545455,\n",
        "    'parch' : 0.379585,\n",
        "    'fare' : 34.385399\n",
        "}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "raZtRlmaj-A5"
      },
      "source": [
        "前と同様に、この関数が実際に何をしているかを見るため、連続値のテンソルを1つ取り、処理前と処理後を見てみます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "G-t_RSBrM2Vm",
        "colab": {}
      },
      "source": [
        "age_tensor = examples['age']\n",
        "age_tensor"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "M9lMLaEsjq3K",
        "colab": {}
      },
      "source": [
        "process_continuous_data(age_tensor, MEANS['age'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kPWkC4_1l3IG"
      },
      "source": [
        "### データの前処理"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jIvyqVAXmsN4"
      },
      "source": [
        "これらの前処理のタスクを1つの関数にまとめ、データセット内のバッチにマッピングできるようにします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rMxEHN0SNPkC",
        "colab": {}
      },
      "source": [
        "def preprocess(features, labels):\n",
        "  \n",
        "  # カテゴリ特徴量の処理\n",
        "  for feature in CATEGORIES.keys():\n",
        "    features[feature] = process_categorical_data(features[feature],\n",
        "                                                 CATEGORIES[feature])\n",
        "\n",
        "  # 連続特徴量の処理\n",
        "  for feature in MEANS.keys():\n",
        "    features[feature] = process_continuous_data(features[feature],\n",
        "                                                MEANS[feature])\n",
        "  \n",
        "  # 特徴量を1つのテンソルに組み立てる\n",
        "  features = tf.concat([features[column] for column in FEATURE_COLUMNS], 1)\n",
        "  \n",
        "  return features, labels\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "34K5ESbYnkg4"
      },
      "source": [
        "次に、 `tf.Dataset.map` 関数を使って適用し、過学習を防ぐためにデータセットをシャッフルします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "7M5km0f_1pVp",
        "colab": {}
      },
      "source": [
        "train_data = raw_train_data.map(preprocess).shuffle(500)\n",
        "test_data = raw_test_data.map(preprocess)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IQOWatzRr2aF"
      },
      "source": [
        "サンプル1個がどうなっているか見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Gc1o9ZpCsGGM",
        "colab": {}
      },
      "source": [
        "examples, labels = next(iter(train_data))\n",
        "\n",
        "examples, labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aJnOromrse57"
      },
      "source": [
        "このサンプルは、（バッチサイズである）12個のアイテムをもつ2次元の配列からできています。アイテムそれぞれは、元の CSV ファイルの1行を表しています。ラベルは12個の値をもつ1次元のテンソルです。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DlF_omQqtnOP"
      },
      "source": [
        "## モデルの構築"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lQoFh16LxtT_"
      },
      "source": [
        "この例では、[Keras Functional API](https://www.tensorflow.org/guide/keras/functional) を使用し、単純なモデルを構築するために `get_model` コンストラクタでラッピングしています。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "JDM3FIgHNCW3",
        "colab": {}
      },
      "source": [
        "def get_model(input_dim, hidden_units=[100]):\n",
        "  \"\"\"複数の層を持つ Keras モデルを作成\n",
        "\n",
        "  引数:\n",
        "    input_dim: (int) バッチ中のアイテムの形状\n",
        "    labels_dim: (int) ラベルの形状\n",
        "    hidden_units: [int] DNN の層のサイズ（入力層が先）\n",
        "    learning_rate: (float) オプティマイザの学習率\n",
        "    \n",
        "  戻り値:\n",
        "    Keras モデル\n",
        "  \"\"\"\n",
        "\n",
        "  inputs = tf.keras.Input(shape=(input_dim,))\n",
        "  x = inputs\n",
        "\n",
        "  for units in hidden_units:\n",
        "    x = tf.keras.layers.Dense(units, activation='relu')(x)\n",
        "  outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n",
        "\n",
        "  model = tf.keras.Model(inputs, outputs)\n",
        " \n",
        "  return model"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ce9PRb_LzFpm"
      },
      "source": [
        "`get_model` コンストラクタは入力データの形状（バッチサイズを除く）を知っている必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qX-DU34ZuKJX",
        "colab": {}
      },
      "source": [
        "input_shape, output_shape = train_data.output_shapes\n",
        "\n",
        "input_dimension = input_shape.dims[1] # [0] はバッチサイズ"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hPdtI2ie0lEZ"
      },
      "source": [
        "## 訓練、評価、そして予測"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8gvw1RE9zXkD"
      },
      "source": [
        "これでモデルをインスタンス化し、訓練することができます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Q_nm28IzNDTO",
        "colab": {}
      },
      "source": [
        "model = get_model(input_dimension)\n",
        "model.compile(\n",
        "    loss='binary_crossentropy',\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "model.fit(train_data, epochs=20)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QyDMgBurzqQo"
      },
      "source": [
        "モデルの訓練が終わったら、`test_data` データセットでの正解率をチェックできます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eB3R3ViVONOp",
        "colab": {}
      },
      "source": [
        "test_loss, test_accuracy = model.evaluate(test_data)\n",
        "\n",
        "print('\\n\\nTest Loss {}, Test Accuracy {}'.format(test_loss, test_accuracy))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sTrn_pD90gdJ"
      },
      "source": [
        "単一のバッチ、または、バッチからなるデータセットのラベルを推論する場合には、`tf.keras.Model.predict` を使います。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Qwcx74F3ojqe",
        "colab": {}
      },
      "source": [
        "predictions = model.predict(test_data)\n",
        "\n",
        "# 結果のいくつかを表示\n",
        "for prediction, survived in zip(predictions[:10], list(test_data)[0][1][:10]):\n",
        "  print(\"Predicted survival: {:.2%}\".format(prediction[0]),\n",
        "        \" | Actual outcome: \",\n",
        "        (\"SURVIVED\" if bool(survived) else \"DIED\"))\n"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
